import os
import sys
import gzip
import math
import tempfile
import shutil
from collections import defaultdict
import pybedtools
import pysam
from Bio.Seq import reverse_complement


# Command-line arguments: the dataset name first, then the library name.
# Both are used below to build the input BAM path and to pick the expected
# read-name prefix.
dataset = sys.argv[1]
library = sys.argv[2]


# Distance threshold (relative to the gene's TSS attribute) separating the
# "upstream"/"proximal" annotations from their "distal" counterparts; it is
# also passed as the `l` window size when searching for genes near a position.
distal_distance = 2000

# Values of the XT tag for which an alignment is skipped when collecting
# positions to annotate.
skip_targets = ('chrM', 'rRNA', 'tRNA', 'snRNA', 'scRNA', 'snoRNA', 'yRNA',
                'scaRNA', 'snar', 'vRNA', 'RMRP', 'RPPH',
               )

# Every XT tag value that is not skipped is asserted to be one of these.
keep_targets = ('TERC', 'MALAT1', 'snhg', 'histone',
                'mRNA', 'lncRNA', 'gencode', 'fantomcat', 'novel', 'genome')


# Values of the XA tag that mark an alignment as already annotated; such
# alignments are skipped when collecting positions to annotate.
annotated = ('presnRNA', 'pretRNA', 'presnoRNA', 'prescaRNA',
             'FANTOM5_enhancer', 'roadmap_enhancer', 'roadmap_dyadic',
             'novel_enhancer_HiSeq', "novel_enhancer_CAGE",
            )

# Gene-relative annotation labels in decreasing order of priority: when a read
# has candidates with several labels, the one listed first here wins.
preferred_annotations = ('sense_proximal',
                         'sense_upstream',
                         'sense_distal',
                         'sense_distal_upstream',
                         'prompt',
                         'antisense',
                         'antisense_distal',
                         'antisense_distal_upstream')

def find_preferred_annotation(overlap, preferences=None):
    """Return the highest-priority annotation label present in `overlap`.

    Parameters:
        overlap: mapping from alignment index to a list of
            (gene_name, annotation, direction, distance) rows.
        preferences: annotation labels in decreasing order of priority.
            Defaults to the module-level `preferred_annotations`.

    Returns:
        The first label of `preferences` that occurs in any row.

    Raises:
        Exception: if no row carries a label listed in `preferences`
            (including the case where `overlap` has no rows at all).
    """
    if preferences is None:
        preferences = preferred_annotations
    # Collect the labels once instead of rescanning every row per preference.
    found = [annotation
             for rows in overlap.values()
             for gene_name, annotation, direction, distance in rows]
    for preferred_annotation in preferences:
        if preferred_annotation in found:
            return preferred_annotation
    # Format the message explicitly; the labels seen are reported so that an
    # unexpected label is visible in the error, and an empty overlap no longer
    # fails on an unbound variable.
    seen = ", ".join(sorted(set(map(str, found))))
    if not seen:
        seen = "none"
    raise Exception("Failed to find a preferred annotation (found: %s)" % seen)

def select_preferred_annotation(overlap):
    """Reduce `overlap` in place to the rows with the best-ranked annotation.

    `overlap` maps an alignment index to a list of
    (gene_name, annotation, direction, distance) rows. The winning label is
    chosen by find_preferred_annotation(); rows with any other label are
    dropped, and indices left without rows are removed from the mapping.
    """
    best = find_preferred_annotation(overlap)
    # Snapshot the keys because entries are deleted while walking them.
    for key in list(overlap.keys()):
        survivors = [
            [gene, label, orientation, offset]
            for gene, label, orientation, offset in overlap[key]
            if label == best
        ]
        if not survivors:
            del overlap[key]
        else:
            overlap[key] = survivors

def select_closest_overlap(overlap):
    """Collapse `overlap` in place to the rows nearest to zero distance.

    `overlap` maps an alignment index to a list of
    [gene_name, annotation, direction, distance] rows. The smallest absolute
    distance over all indices is determined first; each index then keeps only
    its rows at exactly that absolute distance and is replaced by a single
    (gene_names, annotation, distance) tuple, where gene_names is the
    comma-joined list of the tied genes. Indices without such a row are
    removed.

    Tied rows of one index are asserted to agree on annotation, direction and
    distance, and to name distinct genes.
    """
    keys = list(overlap.keys())
    nearest = min(
        (abs(offset) for key in keys for _, _, _, offset in overlap[key]),
        default=math.inf,
    )
    for key in keys:
        winners = [
            [gene, label, orientation, offset]
            for gene, label, orientation, offset in overlap[key]
            if abs(offset) == nearest
        ]
        if not winners:
            del overlap[key]
            continue
        first_gene, label, orientation, offset = winners[0]
        genes = [first_gene]
        for other in winners[1:]:
            assert other[1:] == [label, orientation, offset]
            genes.append(other[0])
        assert len(genes) == len(set(genes))
        overlap[key] = (",".join(genes), label, offset)

def find_gene_associations(genes, sequences):
    """Associate each position in `sequences` with the nearest suitable gene(s).

    Parameters:
        genes: pybedtools.BedTool of gene records. Each record is read as the
            first 9 fields of a window hit and must carry a 'TSS' attribute.
        sequences: pybedtools.BedTool of single-base intervals (end == start + 1)
            whose name identifies the read and whose score is the alignment
            index of that read.

    Returns:
        A plain dict mapping read name -> {alignment index ->
        (gene_names, annotation, distance)}, after keeping only the
        best-ranked annotation and then only the closest gene(s) per read.

    Raises:
        Exception: on a gene strand other than '+' or '-'.
    """
    # Labels per search direction, ordered by position relative to the TSS:
    # (beyond distal_distance upstream, upstream, within distal_distance
    #  downstream, beyond distal_distance downstream).
    labels = {
        'sense': ('sense_distal_upstream', 'sense_upstream',
                  'sense_proximal', 'sense_distal'),
        'antisense': ('antisense_distal_upstream', 'prompt',
                      'antisense', 'antisense_distal'),
    }
    overlap = defaultdict(lambda: defaultdict(list))
    for direction in ('sense', 'antisense'):
        # The gene window is extended by distal_distance on the strand-aware
        # left side (l=..., r=0, sw=True); sm/Sm restrict the hits to the
        # same/opposite strand, which the strand assertions below double-check.
        if direction == 'sense':
            lines = genes.window(sequences, sm=True, l=distal_distance, r=0, sw=True)
        elif direction == 'antisense':
            lines = genes.window(sequences, Sm=True, l=distal_distance, r=0, sw=True)
        else:
            raise Exception("Unknown direction '%s'; should be 'sense' or 'antisense'" % direction)
        far_upstream, upstream, proximal, far_downstream = labels[direction]
        for line in lines:
            fields = line.fields
            # Each hit is the gene record (9 fields) followed by the position record.
            gene = pybedtools.create_interval_from_list(fields[:9])
            strand = gene.strand
            sequence = pybedtools.create_interval_from_list(fields[9:])
            name = sequence.name
            if direction == 'sense':
                assert sequence.strand == strand
            elif direction == 'antisense':
                assert sequence.strand != strand
            # sequence.start is its 5' end; sequence.end = sequence.start + 1
            assert sequence.end == sequence.start + 1
            tss = int(gene.attrs['TSS'])
            # Signed distance along the gene's strand: negative is upstream of the TSS.
            if strand == '+':
                distance = sequence.start - tss
            elif strand == '-':
                distance = tss - sequence.start
            else:
                raise Exception("Unknown strand")
            if distance < -distal_distance:
                annotation = far_upstream
            elif distance < 0:
                annotation = upstream
            elif distance <= distal_distance:
                annotation = proximal
            else:
                annotation = far_downstream
            index = int(sequence.score)
            overlap[name][index].append([gene.name, annotation, direction, distance])
    # Only values of existing keys are replaced here, so iterating the mapping
    # while assigning is safe.
    for name in overlap:
        select_preferred_annotation(overlap[name])
        select_closest_overlap(overlap[name])
        overlap[name] = dict(overlap[name])
    overlap = dict(overlap)
    return overlap

def read_unannotated_sequences(dataset, library):
    """Yield a single-base interval for each alignment that still needs a gene annotation.

    The BAM file "<library>.bam" in the dataset's Mapping directory is read.
    For the "MiSeq" dataset the records are consumed two at a time, the second
    record being treated as the mate of the first. An alignment is skipped
    when it is unmapped, when its XT tag is in `skip_targets`, or when it
    carries an XA tag listed in `annotated`.

    Parameters:
        dataset: dataset name; selects the input directory and, for "MiSeq",
            paired handling.
        library: library name; the input file is "<library>.bam".

    Yields:
        pybedtools intervals with fields
        [chromosome, tss, tss + 1, read name, HI index, strand], where tss is
        the first aligned reference position for forward alignments and the
        last aligned reference position for reverse alignments.

    Raises:
        ValueError: with the read name, if a kept alignment has no HI tag.
    """
    filename = "%s.bam" % library
    directory = "/osc-fs_home/mdehoon/Data/CASPARs/%s/Mapping" % dataset
    path = os.path.join(directory, filename)
    print("Reading", path)
    paired = (dataset == "MiSeq")
    # The context manager closes the file even if the consumer stops
    # iterating early or an assertion fails.
    with pysam.AlignmentFile(path) as bamfile:
        for line in bamfile:
            if paired:
                line2 = next(bamfile)
            if line.is_unmapped:
                if paired:
                    assert line2.is_unmapped
                continue
            target = line.get_tag("XT")
            if target in skip_targets:
                continue
            assert target in keep_targets
            try:
                annotation = line.get_tag("XA")
            except KeyError:
                # No annotation yet: this alignment needs one.
                pass
            else:
                if annotation in annotated:
                    continue
                assert annotation in preferred_annotations
            if line.is_reverse:
                strand = '-'
                if paired:
                    # The mate must lie upstream on the forward strand.
                    assert not line2.is_reverse
                    assert line.reference_start > line2.reference_start
                    assert line.reference_end > line2.reference_end
                # reference_end is exclusive, so the 5' end is one base before it.
                tss = line.reference_end - 1
            else:
                strand = '+'
                if paired:
                    # The mate must lie downstream on the reverse strand.
                    assert line2.is_reverse
                    assert line.reference_start < line2.reference_start
                    assert line.reference_end < line2.reference_end
                tss = line.reference_start
            chromosome = line.reference_name
            name = line.query_name
            try:
                index = int(line.get_tag("HI"))
            except KeyError:
                raise ValueError(name)
            fields = [chromosome, tss, tss+1, name, str(index), strand]
            interval = pybedtools.create_interval_from_list(fields)
            yield interval

def _read_name_prefix(dataset, library):
    """Return the prefix that every read name of the given library must start with."""
    cage_prefixes = {
        "00_hr_A": "D00261:355:C9DA7ANXX:1",
        "00_hr_C": "D00261:355:C9DA7ANXX:1",
        "00_hr_G": "D00261:373:CAB5TANXX:7",
        "00_hr_H": "D00261:389:CAG09ANXX:6",
        "01_hr_A": "D00261:389:CAG09ANXX:4",
        "01_hr_C": "D00261:389:CAG09ANXX:5",
        "01_hr_G": "D00261:389:CAG09ANXX:5",
        "04_hr_C": "D00261:373:CAB5TANXX:7",
        "04_hr_E": "D00261:373:CAB5TANXX:8",
        "12_hr_A": "D00261:355:C9DA7ANXX:1",
        "12_hr_C": "D00261:355:C9DA7ANXX:1",
        "24_hr_C": "D00261:373:CAB5TANXX:7",
        "24_hr_E": "D00261:373:CAB5TANXX:7",
        "96_hr_A": "D00261:398:CAHCHANXX:3",
        "96_hr_C": "D00261:398:CAHCHANXX:3",
        "96_hr_E": "D00261:398:CAHCHANXX:3",
    }
    if dataset == 'MiSeq':
        return 'M00528:115:000000000-AD0NU:1:'
    if dataset == 'HiSeq':
        return 'HWI-ST554:305:C91WGACXX:'
    if dataset == 'CAGE':
        try:
            return cage_prefixes[library]
        except KeyError:
            raise Exception("Unknown library %s" % library) from None
    if dataset == 'StartSeq':
        return library + "_"
    raise Exception("Unknown dataset %s" % dataset)


def write_annotations(dataset, library, associations):
    """Copy the library's BAM file, adding gene annotations to associated reads.

    The input "<library>.bam" is read from the dataset's Mapping directory and
    a file with the same name is written to the current directory. Reads that
    do not appear in `associations` are copied unchanged. For a read that
    does, only the alignments with an association are written; each gets the
    tags NH (number of kept alignments), HI (new 0-based index), XG (gene
    name), XA (annotation) and XD (distance). The first kept alignment becomes
    the primary one and receives the read's sequence.

    Parameters:
        dataset: dataset name; selects the input directory, the expected
            read-name prefix and, for "MiSeq", paired handling (records are
            consumed two at a time).
        library: library name; input and output are "<library>.bam".
        associations: mapping read name -> {original HI index ->
            (gene_name, annotation, distance)}.

    Raises:
        Exception: for an unknown dataset or CAGE library, or when an
            alignment to be annotated already carries an XA or XG tag.
    """
    prefix = _read_name_prefix(dataset, library)
    paired = (dataset == "MiSeq")
    filename = "%s.bam" % library
    directory = os.path.join("/osc-fs_home/mdehoon/Data/CASPARs/", dataset, "Mapping")
    path = os.path.join(directory, filename)
    print("Reading", path)
    # Context managers make sure both files are closed even when an assertion
    # or exception interrupts the copy.
    with pysam.AlignmentFile(path, "rb") as lines:
        print("Writing", filename)
        with pysam.AlignmentFile(filename, "wb", template=lines) as output:
            current = None
            for line1 in lines:
                name = line1.query_name
                assert name.startswith(prefix)
                if paired:
                    line2 = next(lines)
                    assert line2.query_name == line1.query_name
                if name != current:
                    # A new read starts: check that the previous read kept
                    # exactly as many alignments as it had associations.
                    if current in associations:
                        assert len(associations[current]) == new_index
                    current = name
                    old_index = 0
                    new_index = 0
                    association = associations.get(name)
                    assert not line1.is_secondary
                    # Remember the read sequence(s) in their original
                    # orientation, so they can be given to whichever alignment
                    # ends up being the primary one.
                    if line1.is_reverse:
                        sequence1 = reverse_complement(line1.query_sequence)
                    else:
                        sequence1 = line1.query_sequence
                    if paired:
                        if line2.is_reverse:
                            sequence2 = reverse_complement(line2.query_sequence)
                        else:
                            sequence2 = line2.query_sequence
                else:
                    old_index += 1
                    assert line1.is_secondary
                    if paired:
                        assert line2.is_secondary
                if not line1.is_unmapped:
                    assert old_index == line1.get_tag("HI")
                    if association:
                        multimap = len(association)
                        overlap = association.get(old_index)
                        if not overlap:
                            # This alignment has no gene association: drop it
                            # (it is not written and does not get a new index).
                            continue
                        gene_name, annotation, distance = overlap
                        assert gene_name != "."
                        try:
                            line1.get_tag("XA")
                        except KeyError:
                            pass
                        else:
                            raise Exception("found existing annotation tag XA")
                        try:
                            tag = line1.get_tag("XG")
                        except KeyError:
                            pass
                        else:
                            print(gene_name, annotation, tag)
                            raise Exception("found existing gene tag XG for %s" % name)
                        line1.set_tag("NH", multimap)
                        line1.set_tag("HI", new_index)
                        line1.set_tag("XG", gene_name)
                        line1.set_tag("XA", annotation)
                        line1.set_tag("XD", distance)
                        if new_index == 0:
                            # First kept alignment: make it the primary one.
                            # NOTE(review): per the pysam documentation,
                            # assigning query_sequence invalidates
                            # query_qualities -- confirm that is acceptable.
                            line1.is_secondary = False
                            if line1.is_reverse:
                                line1.query_sequence = reverse_complement(sequence1)
                            else:
                                line1.query_sequence = sequence1
                            if paired:
                                line2.is_secondary = False
                                # Orient the mate's sequence by the mate's own
                                # strand, not by the strand of the first read.
                                if line2.is_reverse:
                                    line2.query_sequence = reverse_complement(sequence2)
                                else:
                                    line2.query_sequence = sequence2
                        else:
                            assert line1.is_secondary
                            if paired:
                                assert line2.is_secondary
                output.write(line1)
                if paired:
                    output.write(line2)
                new_index += 1
            # Same consistency check for the last read in the file.
            if current in associations:
                assert len(associations[current]) == new_index

# Load the gene records from a GFF file in the current directory.
filename = "genes.FANTOM_CAT.THP-1.gff"
print("Reading gene information from %s" % filename)
genes = pybedtools.BedTool(filename)


# Collect the positions that still need a gene annotation. The generator is
# wrapped in a BedTool and saved before and after sorting -- presumably so the
# streamed intervals are materialized to a file that later operations can
# re-read (confirm against the pybedtools saveas() documentation).
sequences = read_unannotated_sequences(dataset, library)
sequences = pybedtools.BedTool(sequences)
sequences = sequences.saveas()
sequences = sequences.sort()
sequences = sequences.saveas()
print("Finding gene associations")
associations = find_gene_associations(genes, sequences)
# Write a copy of the library's BAM file with the associations added as tags.
write_annotations(dataset, library, associations)
